package com.jichangxiu.common.utils;

import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ParameterMode;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.type.TypeHandlerRegistry;

import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

@Slf4j
public class SqlUtils {

    /**
     * Common SQL keywords, separated by {@code |}. Each entry keeps its trailing space so that
     * {@link #filterKeyword(String)} matches keyword usages such as {@code "select "} rather than
     * arbitrary substrings.
     */
    public static final String SQL_REGEX = "select |insert |delete |update |drop |count |exec |chr |mid |master |truncate |char |and |declare ";

    /**
     * Characters allowed in an ORDER BY fragment: letters, digits, underscore, space, comma and
     * dot (so multi-column sorting such as {@code "a.name asc, b.id desc"} is accepted).
     */
    public static final String SQL_PATTERN = "[a-zA-Z0-9_\\ \\,\\.]+";

    // SQL_REGEX split once on a literal '|'. String.split takes a regex, hence the escape.
    // Trailing spaces are intentionally preserved (no trimming).
    private static final List<String> SQL_KEYWORDS =
            Collections.unmodifiableList(Arrays.asList(SQL_REGEX.split("\\|")));

    // Precompiled form of SQL_PATTERN, so it is not recompiled on every validation.
    private static final Pattern ORDER_BY_PATTERN = Pattern.compile(SQL_PATTERN);

    // Any run of whitespace (spaces, tabs, CR/LF) in the SQL text is collapsed to a single space.
    private static final Pattern WHITESPACE_PATTERN = Pattern.compile("\\s+");

    // MyBatis configuration, resolved lazily via reflection on first use and then reused.
    // volatile: it is written from whichever thread first formats a statement and read by others.
    private static volatile Configuration configuration = null;

    // Simple class names of parameter values that are rendered wrapped in single quotes.
    private static final Set<String> NEED_BRACKETS = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("String", "Date", "Time", "LocalDate", "LocalTime", "LocalDateTime", "BigDecimal", "Timestamp")));

    /**
     * Checks an ORDER BY fragment to prevent injection bypasses.
     *
     * @param value the ORDER BY fragment; {@code null} or empty values are passed through
     * @return {@code value}, unchanged, when it is empty or valid
     * @throws RuntimeException if {@code value} is non-empty and contains characters outside
     *                          {@link #SQL_PATTERN}
     */
    public static String escapeOrderBySql(String value) {
        if (StrUtil.isNotEmpty(value) && !isValidOrderBySql(value)) {
            throw new RuntimeException("【SqlUtils】参数不符合规范");
        }
        return value;
    }

    /**
     * Returns whether the whole of {@code value} consists only of characters allowed by
     * {@link #SQL_PATTERN}.
     *
     * @param value the ORDER BY fragment to check
     * @return {@code true} if it fully matches; {@code false} otherwise, including for {@code null}
     */
    public static boolean isValidOrderBySql(String value) {
        return value != null && ORDER_BY_PATTERN.matcher(value).matches();
    }

    /**
     * Rejects values containing any keyword from {@link #SQL_REGEX} (case-insensitive).
     *
     * @param value the value to check; {@code null} or empty values are accepted
     * @throws RuntimeException if a keyword (including its trailing space) occurs in {@code value}
     */
    public static void filterKeyword(String value) {
        if (StrUtil.isEmpty(value)) {
            return;
        }
        for (String sqlKeyword : SQL_KEYWORDS) {
            if (StringUtils.indexOfIgnoreCase(value, sqlKeyword) > -1) {
                throw new RuntimeException("【SqlUtils】参数存在【SQL】注入风险");
            }
        }
    }

    /**
     * Builds the SQL of a statement with its bound parameter values inlined.
     *
     * <p>The {@link Configuration} is read once, reflectively, from the {@code configuration}
     * field of the statement's {@link ParameterHandler}, then cached for later calls.
     *
     * @param target expected to be a MyBatis {@link StatementHandler}
     * @return the formatted SQL, see {@link #formatMyBatisPlusSql(BoundSql, Configuration)}
     * @throws RuntimeException wrapping any failure (wrong target type, reflection failure, or an
     *                          error while resolving parameter values)
     */
    public static String getMyBatisPlusSql(Object target) {
        try {
            StatementHandler statementHandler = (StatementHandler) target;
            BoundSql boundSql = statementHandler.getBoundSql();
            // Read the volatile field once; on a race both threads resolve the same object.
            Configuration config = configuration;
            if (config == null) {
                ParameterHandler parameterHandler = statementHandler.getParameterHandler();
                config = (Configuration) FieldUtils.readField(parameterHandler, "configuration", true);
                configuration = config;
            }
            return formatMyBatisPlusSql(boundSql, config);
        } catch (Exception ex) {
            // Wrap and rethrow with the original cause preserved.
            throw new RuntimeException("【SqlUtils】获取【SQL】语句失败", ex);
        }
    }

    /**
     * Produces a readable SQL string with each {@code ?} placeholder replaced by its bound value.
     *
     * <p>Whitespace runs are collapsed to single spaces. OUT-mode parameter mappings are skipped.
     * Each value is resolved in this order:
     * <ol>
     *   <li>the bound SQL's additional parameters;</li>
     *   <li>{@code null} when there is no parameter object;</li>
     *   <li>the parameter object itself when a TypeHandler exists for its class;</li>
     *   <li>otherwise the named property of the parameter object.</li>
     * </ol>
     * The result is meant for diagnostics. Values are rendered via {@code toString()}, which is
     * not a faithful literal for every type or dialect.
     *
     * @param boundSql      the bound SQL; {@code null} yields {@code ""}
     * @param configuration the MyBatis configuration; {@code null} yields {@code ""}
     * @return the formatted SQL, or {@code ""} when there is nothing to format
     */
    public static String formatMyBatisPlusSql(BoundSql boundSql, Configuration configuration) {
        if (boundSql == null || Objects.isNull(configuration)) {
            return "";
        }
        String sql = boundSql.getSql();
        if (StringUtils.isEmpty(sql)) {
            return "";
        }
        sql = WHITESPACE_PATTERN.matcher(sql).replaceAll(" ");
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (parameterMappings == null) {
            return sql;
        }
        List<ParameterMapping> inputMappings = parameterMappings.stream()
                .filter(it -> it.getMode() != ParameterMode.OUT)
                .collect(Collectors.toList());

        TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
        Object parameterObject = boundSql.getParameterObject();
        // Created lazily and reused: it wraps the same parameter object for every placeholder.
        MetaObject metaObject = null;

        StringBuilder result = new StringBuilder(sql);
        int mappingIndex = inputMappings.size() - 1;
        // Walk backwards, pairing the last '?' with the last mapping. Text inserted for a
        // placeholder is never rescanned, even if a value contains '?', and earlier indices stay
        // valid. Stop once mappings run out: a stray '?' (e.g. inside a string literal) must not
        // cause an IndexOutOfBoundsException.
        for (int i = result.length() - 1; i >= 0 && mappingIndex >= 0; i--) {
            if (result.charAt(i) != '?') {
                continue;
            }
            String propertyName = inputMappings.get(mappingIndex--).getProperty();
            Object value;
            if (boundSql.hasAdditionalParameter(propertyName)) {
                value = boundSql.getAdditionalParameter(propertyName);
            } else if (parameterObject == null) {
                value = null;
            } else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                value = parameterObject;
            } else {
                if (metaObject == null) {
                    metaObject = configuration.newMetaObject(parameterObject);
                }
                value = metaObject.getValue(propertyName);
            }
            result.replace(i, i + 1, toSqlLiteral(value));
        }
        return result.toString();
    }

    // Renders one bound value. null becomes "null". Types listed in NEED_BRACKETS are wrapped in
    // single quotes with embedded quotes doubled, so the literal stays well-formed. Anything else
    // uses toString().
    private static String toSqlLiteral(Object value) {
        if (value == null) {
            return "null";
        }
        String text = String.valueOf(value);
        if (NEED_BRACKETS.contains(value.getClass().getSimpleName())) {
            return "'" + text.replace("'", "''") + "'";
        }
        return text;
    }

}
